{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SslnlBaKPeWP"
   },
   "source": [
    "# Wikipedia Semantic Search with Cohere Embedding Archives\n",
    "This notebook contains starter code for simple [semantic search](https://txt.cohere.ai/what-is-semantic-search/) over the [Wikipedia embeddings archives](https://txt.cohere.ai/embedding-archives-wikipedia/) published by Cohere. These archives contain embeddings of Wikipedia in multiple languages. In this example, we'll use [Wikipedia Simple English](https://huggingface.co/datasets/Cohere/wikipedia-22-12-simple-embeddings)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IUnwp2cYNnP0"
   },
   "outputs": [],
   "source": [
    "# Install the Cohere SDK (pinned to \"cohere<5\") and Hugging Face datasets\n",
    "# TODO: upgrade to \"cohere>5\"\n",
    "!pip install \"cohere<5\" datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hZds1apHPsag"
   },
   "source": [
    "Let's now stream the first 1,000 records from the Simple English Wikipedia embeddings archive so we can search them afterwards."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "v8Pogz7gPQwg"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5863371d28d14ce0bf0cb80643c66d21",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/1.29k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration Cohere--wikipedia-22-12-simple-embeddings-94deea3d55a22093\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "import torch\n",
    "import cohere\n",
    "\n",
    "# Add your Cohere API key (available from www.cohere.com)\n",
    "co = cohere.Client(\"\")\n",
    "\n",
    "# Load at most 1,000 documents and their embeddings\n",
    "max_docs = 1000\n",
    "docs_stream = load_dataset(\"Cohere/wikipedia-22-12-simple-embeddings\", split=\"train\", streaming=True)\n",
    "\n",
    "docs = []\n",
    "doc_embeddings = []\n",
    "\n",
    "for doc in docs_stream:\n",
    "    docs.append(doc)\n",
    "    doc_embeddings.append(doc['emb'])\n",
    "    if len(docs) >= max_docs:\n",
    "        break\n",
    "\n",
    "doc_embeddings = torch.tensor(doc_embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VIlx5RVCP7g7"
   },
   "source": [
    "Now, `doc_embeddings` holds the embeddings of the first 1,000 documents in the dataset. Each document is represented as an [embedding vector](https://txt.cohere.ai/sentence-word-embeddings/) of 768 values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "OBa3oxSsP2fv",
    "outputId": "d9d71135-7ac3-4424-d806-2a994a0b456a"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1000, 768])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "doc_embeddings.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GbYAXaI4RQiH"
   },
   "source": [
    "We can now search these vectors for any query we want. For this toy example, we'll ask a question about Wikipedia, since we know the Wikipedia page is among the first 1,000 documents we loaded.\n",
    "\n",
    "To search, we embed the query, then get the nearest neighbors to its embedding (using dot product)."
   ]
  },
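  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the nearest-neighbor step, the toy example below scores a query against four made-up 3-dimensional vectors using plain Python rather than `torch`. The names `toy_docs`, `toy_query`, `dot`, and `top_k_indices` are hypothetical and exist only for this illustration; the real embeddings have 768 values.\n",
    "\n",
    "```python\n",
    "# Toy stand-ins for document embeddings (made-up values, not from the archive)\n",
    "toy_docs = [\n",
    "    [1.0, 0.0, 0.0],\n",
    "    [0.0, 1.0, 0.0],\n",
    "    [0.5, 0.5, 0.0],\n",
    "    [0.0, 0.0, 1.0],\n",
    "]\n",
    "# Toy stand-in for the query embedding\n",
    "toy_query = [1.0, 0.25, 0.0]\n",
    "\n",
    "\n",
    "def dot(a, b):\n",
    "    # Dot product: sum of element-wise products\n",
    "    return sum(x * y for x, y in zip(a, b))\n",
    "\n",
    "\n",
    "def top_k_indices(query, docs, k):\n",
    "    # Score every document, then return the indices of the k highest scores\n",
    "    scores = [dot(query, d) for d in docs]\n",
    "    return sorted(range(len(docs)), key=lambda i: scores[i], reverse=True)[:k]\n",
    "\n",
    "\n",
    "top = top_k_indices(toy_query, toy_docs, k=2)\n",
    "print(top)  # [0, 2]\n",
    "```\n",
    "\n",
    "Documents 0 and 2 rank highest because they have the largest dot products with the query (1.0 and 0.625). Calling `torch.mm` and then `torch.topk` performs the same scoring and selection on the full embedding matrix."
   ]
  },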
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "SJGUurziNiYR",
    "outputId": "bb66def9-3d83-46f7-c871-1224eb5714cd"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: Who founded Wikipedia\n",
      "Wikipedia\n",
      "Larry Sanger and Jimmy Wales are the ones who started Wikipedia. Wales is credited with defining the goals of the project. Sanger created the strategy of using a wiki to reach Wales' goal. On January 10, 2001, Larry Sanger proposed on the Nupedia mailing list to create a wiki as a \"feeder\" project for Nupedia. Wikipedia was launched on January 15, 2001. It was launched as an English-language edition at www.wikipedia.com, and announced by Sanger on the Nupedia mailing list. Wikipedia's policy of \"neutral point-of-view\" was enforced in its initial months, and was similar to Nupedia's earlier \"nonbiased\" policy. Otherwise, there weren't very many rules initially, and Wikipedia operated independently of Nupedia. \n",
      "\n",
      "Wikipedia\n",
      "Wikipedia began as a related project for Nupedia. Nupedia was a free English-language online encyclopedia project. Nupedia's articles were written and owned by Bomis, Inc which was a web portal company. The important people of the company were Jimmy Wales, the person in charge of Bomis, and Larry Sanger, the editor-in-chief of Nupedia. Nupedia was first licensed under the Nupedia Open Content License which was changed to the GNU Free Documentation License before Wikipedia was founded and made their first article when Richard Stallman requested them. \n",
      "\n",
      "Wikipedia\n",
      "Wikipedia was started on January 10, 2001, by Jimmy Wales and Larry Sanger as part of an earlier online encyclopedia named Nupedia. On January 15, 2001, Wikipedia became a separate website of its own. It is a wiki that uses the software MediaWiki (like all other Wikimedia Foundation projects). \n",
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Get the query, then embed it\n",
    "query = 'Who founded Wikipedia'\n",
    "response = co.embed(texts=[query], model='multilingual-22-12')\n",
    "query_embedding = response.embeddings \n",
    "query_embedding = torch.tensor(query_embedding)\n",
    "\n",
    "# Compute dot-product scores between the query embedding and all document embeddings\n",
    "dot_scores = torch.mm(query_embedding, doc_embeddings.transpose(0, 1))\n",
    "top_k = torch.topk(dot_scores, k=3)\n",
    "\n",
    "# Print results\n",
    "print(\"Query:\", query)\n",
    "for doc_id in top_k.indices[0].tolist():\n",
    "    print(docs[doc_id]['title'])\n",
    "    print(docs[doc_id]['text'], \"\\n\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WWFroOO2RwMd"
   },
   "source": [
    "This shows the top three passages most relevant to the query. To retrieve more results, increase the `k` value passed to `torch.topk`. The question in this simple demo is about Wikipedia because we know the Wikipedia page is among the documents in this subset of the archive."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
